{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "# Quantization-Aware Training with a Keras Transformer Model\n",
    "\n",
    "This notebook provides a working example of using AIMET to perform quantization-aware training (QAT) on a transformer model built in Keras. QAT inserts quantization simulation ops (sometimes called fake quantization ops) into a trained model and then fine-tunes the model for a few epochs with a standard training pipeline. The resulting model should show improved accuracy when run on quantized ML accelerators.\n",
    "\n",
    "#### Overall flow\n",
    "This notebook covers the following steps:\n",
    "1. Load the dataset\n",
    "2. Create the model in Keras\n",
    "3. Train and evaluate the model\n",
    "4. Quantize the model with QuantSim\n",
    "5. Fine-tune the quantized model with QAT\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "***1. Load the dataset***\n",
    "\n",
    "This notebook relies on the IMDB dataset for sentiment analysis, as provided by Keras."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "from tensorflow import keras\n",
    "\n",
    "vocab_size = 20000  # Only consider the top 20k words\n",
    "maxlen = 200  # Keep at most 200 tokens per review (pad_sequences truncates from the front by default)\n",
    "\n",
    "(x_train, y_train), (x_val, y_val) = keras.datasets.imdb.load_data(num_words=vocab_size)\n",
    "print(len(x_train), \"Training sequences\")\n",
    "print(len(x_val), \"Validation sequences\")\n",
    "\n",
    "x_train = keras.preprocessing.sequence.pad_sequences(x_train, maxlen=maxlen)\n",
    "x_val = keras.preprocessing.sequence.pad_sequences(x_val, maxlen=maxlen)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "---\n",
    "***2. Create the model in Keras***\n",
    "\n",
    "Currently, only Keras models built with the Sequential or Functional API are compatible with QuantSim; models that use subclassed layers are not supported. We therefore use the Functional API to build the model in this example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.keras import layers\n",
    "\n",
    "embed_dim = 32  # Embedding size for each token\n",
    "num_heads = 2  # Number of attention heads\n",
    "ff_dim = 32  # Hidden layer size in feed forward network inside transformer\n",
    "\n",
    "############## FUNCTIONAL MODEL ##############\n",
    "inputs = layers.Input(shape=(maxlen,))\n",
    "\n",
    "# Embedding Layer\n",
    "positions = tf.range(start=0, limit=maxlen, delta=1)\n",
    "positions = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)(positions)\n",
    "x = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)(inputs)\n",
    "x = x + positions\n",
    "\n",
    "# Transformer Block\n",
    "x = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)(x, x)\n",
    "x = layers.Dropout(0.1)(x)\n",
    "x = layers.LayerNormalization(epsilon=1e-6)(x)\n",
    "x = layers.Dense(ff_dim, activation=\"relu\")(x)\n",
    "x = layers.Dense(embed_dim)(x)\n",
    "x = layers.Dropout(0.1)(x)\n",
    "x = layers.LayerNormalization(epsilon=1e-6)(x)\n",
    "\n",
    "# Output layers\n",
    "x = layers.GlobalAveragePooling1D()(x)\n",
    "x = layers.Dropout(0.1)(x)\n",
    "x = layers.Dense(20, activation=\"relu\")(x)\n",
    "x = layers.Dropout(0.1)(x)\n",
    "outputs = layers.Dense(2, activation=\"softmax\")(x)\n",
    "################################################\n",
    "\n",
    "functional_model = keras.Model(inputs=inputs, outputs=outputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "---\n",
    "***3. Train and evaluate the model to get a baseline accuracy***\n",
    "\n",
    "Before we can quantize the model and apply QAT, the FP32 model must be trained so that we can get a baseline accuracy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "functional_callback = tf.keras.callbacks.TensorBoard(log_dir=\"./log/functional\", histogram_freq=1)\n",
    "functional_model.compile(optimizer=\"adam\", loss=\"sparse_categorical_crossentropy\", metrics=[\"accuracy\"])\n",
    "history = functional_model.fit(\n",
    "    x_train, y_train, batch_size=32, epochs=1, validation_data=(x_val, y_val), callbacks=[functional_callback]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "# Evaluate the model on the test data using `evaluate`\n",
    "print(\"Evaluate model on test data\")\n",
    "results = functional_model.evaluate(x_val, y_val, batch_size=128)\n",
    "print(\"test loss, test acc:\", results)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "---\n",
    "***4. Create a QuantizationSim Model and determine quantized accuracy***\n",
    "\n",
    "**Create Quantization Sim Model**\n",
    "\n",
    "Now we use AIMET to create a QuantizationSimModel. AIMET inserts fake quantization ops into the model graph and configures them.\n",
    "A few of the parameters are explained here:\n",
    "- **quant_scheme**: We set this to \"QuantScheme.post_training_tf_enhanced\"\n",
    "    - Supported options are the strings 'tf_enhanced' or 'tf', or the QuantScheme enum values QuantScheme.post_training_tf or QuantScheme.post_training_tf_enhanced\n",
    "- **default_output_bw**: Setting this to 8 asks AIMET to quantize all activations in the model to 8-bit integer precision\n",
    "- **default_param_bw**: Setting this to 8 asks AIMET to quantize all parameters in the model to 8-bit integer precision\n",
    "\n",
    "Other parameters are left at their default values in this example. See the AIMET API documentation for QuantizationSimModel for a reference of all parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "from aimet_common.defs import QuantScheme\n",
    "from aimet_tensorflow.keras.quantsim import QuantizationSimModel\n",
    "\n",
    "model = QuantizationSimModel(model=functional_model,\n",
    "                             quant_scheme=QuantScheme.post_training_tf_enhanced,\n",
    "                             rounding_mode='nearest',\n",
    "                             default_output_bw=8,\n",
    "                             default_param_bw=8)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "QuantSim works by wrapping each layer in the model with a quantization wrapper that simulates the effects of quantization on the inputs, outputs, and parameters of the layer (visualized below). A regular Conv2D Keras layer is displayed on the left, while a Conv2D layer after a quantization wrapper has been applied is displayed on the right.\n",
    "![A regular Conv2d layer](../images/keras_pre_quant_layer.png)  &nbsp;&nbsp;   ![A Conv2D layer after a quantization wrapper has been applied](../images/keras_post_quant_layer.png)\n",
    "\n",
    "If a multi-head attention layer is encountered in the model, the original layer is replaced with a custom quantizable version that gives the QuantizationSimModel access to the inputs and outputs of internal ops within the layer, so that quantization wrappers can be applied at a more granular level than the entire MHA layer. This is necessary in order to accurately simulate the effects of on-target quantization.\n",
    "\n",
    "This works by making use of Keras's built-in `clone_model` function, which, given a custom clone function, allows us to clone and modify the FP32 model layer by layer. A more detailed call flow diagram is displayed below.\n",
    "![Keras QuantSim call flow](../images/keras_quantsim_callflow.png)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "---\n",
    "**Compute Encodings**\n",
    "\n",
    "Even though AIMET has added 'quantizer' nodes to the model graph, the model is not ready to be used yet. Before we can use the sim model for inference or training, we need to find appropriate scale/offset quantization parameters for each 'quantizer' node. For activation quantization nodes, we need to pass unlabeled data samples through the model to collect range statistics, which then let AIMET calculate appropriate scale/offset quantization parameters. This process is sometimes referred to as calibration. AIMET simply refers to it as 'computing encodings'.\n",
    "\n",
    "So we create a routine to pass unlabeled data samples through the model. This should be fairly simple: take some samples from the existing training or validation data and pass them to the model. We don't need to compute a loss or any metric, so we can ignore the model output. A few pointers regarding the data samples:\n",
    "- In practice, only a very small percentage of the overall data samples is needed for computing encodings.\n",
    "- It helps if the samples used for computing encodings are well distributed. Not every class needs to be covered, since we are only looking at the range of values at every layer activation. However, we definitely want to avoid an extreme scenario such as using only positive or only negative samples.\n",
    "\n",
    "The following shows an example of a routine that passes unlabeled samples through the model for computing encodings. This routine can be written in many different ways; this is just one example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "model.compute_encodings(lambda m, _: m(x_val[0:1000]), None)\n",
    "model.export('./data', 'model', convert_to_pb=False) # Once the encodings have been computed, export them for later inspection"
   ]
  },
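  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "To make the idea of an encoding more concrete, the cell below is a minimal NumPy sketch of how a scale/offset pair could be derived from an observed value range and then used to simulate quantization. It is an illustration under stated assumptions, not AIMET's implementation: `compute_encoding` and `fake_quantize` are hypothetical helper names, and the plain min/max rule used here is simpler than `post_training_tf_enhanced`, which may select ranges differently."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def compute_encoding(x_min, x_max, bitwidth=8):\n",
    "    # Hypothetical helper: derive an asymmetric (scale, offset) pair from an observed range.\n",
    "    # Widen the range to include zero so that 0.0 lands exactly on the integer grid.\n",
    "    x_min, x_max = min(float(x_min), 0.0), max(float(x_max), 0.0)\n",
    "    num_steps = 2 ** bitwidth - 1\n",
    "    # Guard against a degenerate all-zero range.\n",
    "    scale = max((x_max - x_min) / num_steps, 1e-12)\n",
    "    offset = int(round(x_min / scale))  # grid index of x_min, always <= 0\n",
    "    return scale, offset\n",
    "\n",
    "\n",
    "def fake_quantize(x, scale, offset, bitwidth=8):\n",
    "    # Hypothetical helper: snap values to the integer grid, clip, and map back to float.\n",
    "    num_steps = 2 ** bitwidth - 1\n",
    "    q = np.clip(np.round(x / scale) - offset, 0, num_steps)\n",
    "    return (q + offset) * scale\n",
    "\n",
    "\n",
    "# Made-up activation values, used only to exercise the two helpers.\n",
    "activations = np.array([-1.0, 0.0, 0.5, 2.0])\n",
    "scale, offset = compute_encoding(activations.min(), activations.max())\n",
    "print(\"scale:\", scale, \"offset:\", offset)\n",
    "print(\"fake-quantized:\", fake_quantize(activations, scale, offset))"
   ]
  },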
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "Next, we can evaluate the performance of the quantized model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "model.model.compile(optimizer=\"adam\", loss=\"sparse_categorical_crossentropy\", metrics=[\"accuracy\"]) # must compile model before evaluating\n",
    "\n",
    "print(\"Evaluate quantized model on test data\")\n",
    "results = model.model.evaluate(x_val, y_val, batch_size=128)\n",
    "print(\"test loss, test acc:\", results)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "---\n",
    "***5. Perform QAT***\n",
    "\n",
    "To perform quantization-aware training (QAT), we simply train the model for a few more epochs (typically 15-20). As with any training job, hyper-parameters need to be searched for optimal results. Good starting points are to use a learning rate on the same order as the ending learning rate when training the original model, and to drop the learning rate by a factor of 10 every 5 epochs or so.\n",
    "For the purpose of this example notebook, we train for only 1 epoch on the first 1,024 training samples. Feel free to change these parameters as you see fit."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "quantized_callback = tf.keras.callbacks.TensorBoard(log_dir=\"./log/quantized\")\n",
    "history = model.model.fit(\n",
    "    x_train[0:1024], y_train[0:1024], batch_size=32, epochs=1, validation_data=(x_val, y_val), callbacks=[quantized_callback]\n",
    ")"
   ]
  },
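  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "The learning-rate guidance above (drop by a factor of 10 every 5 epochs or so) can be sketched as a simple step-decay function. This is only an illustration: `step_decay` is a hypothetical helper, and the starting value of 1e-3 is an assumption based on the default learning rate of the Keras Adam optimizer used earlier. For a longer QAT run, one option is to wrap the function in a `tf.keras.callbacks.LearningRateScheduler` callback, as outlined in the trailing comments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "def step_decay(epoch, initial_lr=1e-3, drop_factor=0.1, epochs_per_drop=5):\n",
    "    # Hypothetical helper: learning rate for a 0-indexed epoch,\n",
    "    # multiplied by drop_factor once every epochs_per_drop epochs.\n",
    "    return initial_lr * drop_factor ** (epoch // epochs_per_drop)\n",
    "\n",
    "\n",
    "# Print the schedule at a few epochs to see where the drops occur.\n",
    "for epoch in (0, 4, 5, 10, 15):\n",
    "    print(epoch, step_decay(epoch))\n",
    "\n",
    "# Possible use in a longer QAT run (assumes `tf`, `model` and `quantized_callback` from the cells above):\n",
    "# lr_callback = tf.keras.callbacks.LearningRateScheduler(lambda epoch, lr: step_decay(epoch))\n",
    "# model.model.fit(x_train, y_train, batch_size=32, epochs=15, callbacks=[quantized_callback, lr_callback])"
   ]
  },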
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "Now, let's compute and export the encodings of the model after performing QAT. When comparing the encodings file generated by this step with the one exported before QAT, there should be some differences. These differences are a result of QAT."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "model.compute_encodings(lambda m, _: m(x_val[0:3000]), None)\n",
    "model.export('./data', 'model_after_qat', convert_to_pb=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "Finally, let's evaluate the validation accuracy of our model after QAT."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "print(\"Evaluate quantized model (post QAT) on test data\")\n",
    "results = model.model.evaluate(x_val, y_val, batch_size=128)\n",
    "print(\"test loss, test acc:\", results)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "We can also use TensorBoard to visualize the FP32 and quantized models and see how they differ. Comparing the two, we can see that most layers are now replaced with a quantization wrapper that simulates the effects of quantization at the input and output nodes of the layer. In the case of more complex layers, like multi-head attention, QuantSim has custom pipelines to insert quantization wrappers around more elementary ops within the layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "%load_ext tensorboard\n",
    "%tensorboard --logdir ./log\n",
    "from tensorboard import notebook\n",
    "\n",
    "notebook.display(height=1000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "---\n",
    "***Summary***\n",
    "\n",
    "We hope this notebook helped you understand how to use AIMET with Keras models.\n",
    "\n",
    "A few additional resources:\n",
    "- Refer to the AIMET API docs for more details on the APIs and optional parameters\n",
    "- Refer to the other example notebooks to understand how to use AIMET post-training quantization techniques and the vanilla QAT method (without range-learning)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
